{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import namedtuple\n",
    "# Each field is paired with its default so names and defaults cannot drift apart.\n",
    "_LAYER_PARAM_DEFAULTS = (\n",
    "    (\"percent_on_k_winner\", 0.25),\n",
    "    (\"boost_strength\", 1.4),\n",
    "    (\"boost_strength_factor\", 0.7),\n",
    "    (\"k_inference_factor\", 1.0),\n",
    "    (\"local\", False),\n",
    "    (\"weights_density\", 0.5),\n",
    ")\n",
    "\n",
    "# Every field has a default, so LayerParams() is valid and callers override by keyword.\n",
    "LayerParams = namedtuple(\n",
    "    \"LayerParams\",\n",
    "    [field for field, _ in _LAYER_PARAM_DEFAULTS],\n",
    "    defaults=[default for _, default in _LAYER_PARAM_DEFAULTS],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override only boost_strength; every other field keeps its declared default.\n",
    "lp = LayerParams(boost_strength=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.25, 0.5, 0.7, 1.0, False, 0.5)\n"
     ]
    }
   ],
   "source": [
    "def prt(*args):\n",
    "    \"\"\"Print all positional arguments as a single tuple.\"\"\"\n",
    "    print(args)\n",
    "\n",
    "# Unpacking the namedtuple passes each field as its own positional argument.\n",
    "prt(*lp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations, permutations, combinations_with_replacement, product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0, 0)\n",
      "(0, 1)\n",
      "(0, 2)\n",
      "(0, 3)\n",
      "(0, 4)\n",
      "(1, 0)\n",
      "(1, 1)\n",
      "(1, 2)\n",
      "(1, 3)\n",
      "(1, 4)\n",
      "(2, 0)\n",
      "(2, 1)\n",
      "(2, 2)\n",
      "(2, 3)\n",
      "(2, 4)\n",
      "(3, 0)\n",
      "(3, 1)\n",
      "(3, 2)\n",
      "(3, 3)\n",
      "(3, 4)\n",
      "(4, 0)\n",
      "(4, 1)\n",
      "(4, 2)\n",
      "(4, 3)\n",
      "(4, 4)\n"
     ]
    }
   ],
   "source": [
    "# Enumerate every ordered pair drawn from range(5) x range(5), first element varying slowest.\n",
    "for idx in product(range(5), repeat=2):\n",
    "    print(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "# A 5x5 tensor of uniform random values and a 5x5 tensor of ones, compared in later cells.\n",
    "t = torch.rand((5, 5))\n",
    "q = torch.ones((5, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.1917)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[True, True, True, True, True],\n",
       "        [True, True, True, True, True],\n",
       "        [True, True, True, True, True],\n",
       "        [True, True, True, True, True],\n",
       "        [True, True, True, True, True]])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.eq(t,q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.allclose(t,q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-written lists of integer pairs; the next cell intersects them.\n",
    "# NOTE(review): names suggest lowest-25% Hebbian and lowest-50% magnitude positions - confirm.\n",
    "lowest_25_hebb = [(4,0), (2,1), (0,1)]\n",
    "lowest_50_mag = [(2,0), (2,2), (4,2), (0,1), (3,2), (4,1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{(0, 1)}"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pairs that appear in both lists.\n",
    "set(lowest_25_hebb) & set(lowest_50_mag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5x5 float tensor of hand-entered values, all within [0, 1).\n",
    "corr = torch.tensor(\n",
    "    [\n",
    "        [0.3201, 0.8318, 0.3382, 0.9734, 0.0985],\n",
    "        [0.0401, 0.8620, 0.0845, 0.3778, 0.3996],\n",
    "        [0.4954, 0.0092, 0.6713, 0.8594, 0.9487],\n",
    "        [0.8101, 0.0922, 0.2033, 0.7185, 0.4588],\n",
    "        [0.3897, 0.6865, 0.5072, 0.9749, 0.0597],\n",
    "    ]\n",
    ")\n",
    "\n",
    "# 5x5 integer tensor; row 1 and the last two columns are entirely zero.\n",
    "weight = torch.tensor(\n",
    "    [\n",
    "        [19, 2, -12, 0, 0],\n",
    "        [0, 0, 0, 0, 0],\n",
    "        [-10, 25, -8, 0, 0],\n",
    "        [21, -11, 7, 0, 0],\n",
    "        [-14, 18, -6, 0, 0],\n",
    "    ]\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise product of the transposed corr with weight cast to float.\n",
    "W = torch.mul(corr.t(), weight.float())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 6.0819,  0.0802, -5.9448,  0.0000,  0.0000],\n",
       "        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],\n",
       "        [-3.3820,  2.1125, -5.3704,  0.0000,  0.0000],\n",
       "        [20.4414, -4.1558,  6.0158,  0.0000,  0.0000],\n",
       "        [-1.3790,  7.1928, -5.6922,  0.0000,  0.0000]])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Randomly sampling k elements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2, 1])"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unnormalised sampling weights; the entries at indices 0 and 3 have zero weight.\n",
    "weights = torch.tensor([0.0, 10.0, 3.0, 0.0])\n",
    "torch.multinomial(weights, num_samples=2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.multinomial treats every row of a 2-D input as a separate distribution and raises\n",
    "# when a row sums to zero (row 1 of W has no positive entries). Sampling from the flattened\n",
    "# mask avoids that; the result holds flat indices into W.\n",
    "torch.multinomial((W > 0).float().flatten(), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean mask of the strictly positive entries of W.\n",
    "Wp = W.gt(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row of indices per True entry of Wp (2-D tensor form, not a tuple of index tensors).\n",
    "samples = Wp.nonzero(as_tuple=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0],\n",
       "        [0, 1],\n",
       "        [2, 1],\n",
       "        [3, 0],\n",
       "        [3, 2],\n",
       "        [4, 1]])"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 2])"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([0, 0]),\n",
       " tensor([0, 1]),\n",
       " tensor([2, 1]),\n",
       " tensor([3, 0]),\n",
       " tensor([3, 2]),\n",
       " tensor([4, 1])]"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 0],\n",
      "        [0, 1],\n",
      "        [2, 1],\n",
      "        [3, 0],\n",
      "        [3, 2],\n",
      "        [4, 1]])\n",
      "[0 2 1]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0],\n",
       "        [2, 1],\n",
       "        [0, 1]])"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Pick up to 3 distinct rows of `samples`; replace=False prevents the same\n",
    "# coordinate from being drawn twice.\n",
    "idx = np.random.choice(len(samples), size=min(3, len(samples)), replace=False)\n",
    "print(samples)\n",
    "print(idx)\n",
    "selected = samples[idx]\n",
    "selected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index with a tuple of per-axis index tensors (rows, cols) rather than a list of tuples,\n",
    "# and write a float because W is a float tensor (True was silently cast to 1.0).\n",
    "W[tuple(selected.t())] = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.0000,  1.0000, -5.9448,  0.0000,  0.0000],\n",
       "        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],\n",
       "        [-3.3820,  1.0000, -5.3704,  0.0000,  0.0000],\n",
       "        [20.4414, -4.1558,  6.0158,  0.0000,  0.0000],\n",
       "        [-1.3790,  7.1928, -5.6922,  0.0000,  0.0000]])"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ True,  True, False, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [False,  True, False, False, False],\n",
       "        [ True, False,  True, False, False],\n",
       "        [False,  True, False, False, False]])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Wp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
